import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class SkeletonConv(nn.Module):
    """1-D convolution whose kernel is masked by a per-joint neighbour list.

    Channels are laid out joint-major: joint ``k`` owns input channels
    ``[k * in_channels_per_joint, (k + 1) * in_channels_per_joint)`` and the
    corresponding slice of the output channels.  The output channels of joint
    ``i`` only receive contributions from the input channels of the joints in
    ``neighbour_list[i]``; every other kernel entry is multiplied by a zero
    mask in :meth:`forward`.
    """

    def __init__(self, neighbour_list, in_channels, out_channels, kernel_size, joint_num,
                 stride=1, padding=0, bias=True, padding_mode="zeros"):
        """
        :neighbour_list: for every output joint, the list of input joint indices it may read from
        :in_channels: total number of input channels, must be divisible by ``joint_num``
        :out_channels: total number of output channels, must be divisible by ``joint_num``
        :kernel_size: size of the convolution kernel along the last axis
        :joint_num: number of joints the channels are split into
        :stride: convolution stride
        :padding: amount of padding added to both ends of the last axis
        :bias: if true, a learnable bias is added
        :padding_mode: "zeros" and "reflection" are translated to the names ``F.pad`` uses;
                       any other value is handed to ``F.pad`` unchanged
        :raises ValueError: if the channel counts are not divisible by ``joint_num``
        """
        super().__init__()
        if in_channels % joint_num != 0 or out_channels % joint_num != 0:
            raise ValueError(
                "in_channels ({}) and out_channels ({}) must both be divisible by "
                "joint_num ({})".format(in_channels, out_channels, joint_num))
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # Translate nn.Conv-style padding names into the ones F.pad understands.
        if padding_mode == "zeros":
            padding_mode = "constant"
        elif padding_mode == "reflection":
            padding_mode = "reflect"

        self.neighbour_list = neighbour_list
        self.joint_num = joint_num

        self.stride = stride
        self.dilation = 1
        self.groups = 1
        self.padding = padding
        self.padding_mode = padding_mode
        self._padding_repeated_twice = (padding, padding)

        # Expand joint indices into the input-channel indices they own.
        self.expand_neighbour_list = [
            [k * self.in_channels_per_joint + c
             for k in neighbour
             for c in range(self.in_channels_per_joint)]
            for neighbour in neighbour_list
        ]

        # Parameters are created exactly once; reset_parameters() only fills
        # them in place, so it can safely be called again later.
        self.weight = nn.Parameter(torch.zeros(out_channels, in_channels, kernel_size))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

        mask = torch.zeros(out_channels, in_channels, kernel_size)
        for i, neighbour in enumerate(self.expand_neighbour_list):
            mask[self._joint_rows(i), neighbour, ...] = 1
        self.parameter_cnt = mask.sum() + (out_channels if bias else 0)
        # Kept as a frozen Parameter so it follows the module across devices
        # and is stored in the state dict under the key "mask".
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.description = 'SkeletonConv(in_channels_per_armature={}, out_channels_per_armature={}, kernel_size={}, ' \
                           'joint_num={}, stride={}, padding={}, bias={})\n' \
                           'Total Learnable Parameters={}'.format(
            in_channels // joint_num, out_channels // joint_num, kernel_size,
            joint_num, stride, padding, bias, self.parameter_cnt
        )
        self.reset_parameters()

    def _joint_rows(self, i):
        """Return the slice of output channels that belongs to joint ``i``."""
        return slice(self.out_channels_per_joint * i, self.out_channels_per_joint * (i + 1))

    def reset_parameters(self):
        """(Re-)initialise the weight and bias block of every joint.

        Each joint's block is drawn with Kaiming-uniform initialisation, using
        only its own neighbourhood as fan-in.  Entries outside the
        neighbourhood are never written.
        """
        # In-place writes to a leaf that requires grad are only legal without
        # autograd tracking.
        with torch.no_grad():
            for i, neighbour in enumerate(self.expand_neighbour_list):
                rows = self._joint_rows(i)
                # Advanced indexing yields a copy, so initialise a temporary
                # block and write it back explicitly.
                block = torch.zeros_like(self.weight[rows, neighbour, ...])
                nn.init.kaiming_uniform_(block, a=math.sqrt(5))
                self.weight[rows, neighbour, ...] = block
                if self.bias is not None:
                    # fan-in of a conv kernel: input channels * kernel size
                    fan_in = block.shape[1] * block.shape[2]
                    bound = 1 / math.sqrt(fan_in)
                    bias_block = torch.zeros_like(self.bias[rows])
                    nn.init.uniform_(bias_block, -bound, bound)
                    self.bias[rows] = bias_block

    def forward(self, input):
        """Pad ``input`` along its last axis and apply the masked convolution.

        ``input`` is handed to ``F.pad`` and ``F.conv1d``, so it must have the
        channel axis (size ``in_channels``) directly before the last axis.
        """
        weight_masked = self.weight * self.mask
        padded = F.pad(input, self._padding_repeated_twice, mode=self.padding_mode)
        # Padding was already applied above, hence padding=0 here.
        return F.conv1d(padded, weight_masked, self.bias, self.stride,
                        0, self.dilation, self.groups)

    def __repr__(self):
        return self.description


class SkeletonLinear(nn.Module):
    """Fully connected layer whose weight is masked by a per-joint neighbour list.

    Features are laid out joint-major.  The output features of joint ``i``
    only depend on the input features of the joints in ``neighbour_list[i]``;
    all other weight entries are multiplied by a zero mask in :meth:`forward`.
    The number of joints is ``len(neighbour_list)``.
    """

    def __init__(self, neighbour_list, in_channels, out_channels):
        """
        :neighbour_list: for every output joint, the list of input joint indices it may read from
        :in_channels: total number of input features
        :out_channels: total number of output features
        """
        super().__init__()
        joint_num = len(neighbour_list)
        self.neighbour_list = neighbour_list
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.in_channels_per_joint = in_channels // joint_num
        self.out_channels_per_joint = out_channels // joint_num

        # Expand joint indices into the input-feature indices they own.
        self.expanded_neighbour_list = [
            [k * self.in_channels_per_joint + c
             for k in neighbour
             for c in range(self.in_channels_per_joint)]
            for neighbour in neighbour_list
        ]

        # Parameters are created exactly once; reset_parameters() only fills
        # them in place, so it can safely be called again later.
        self.weight = nn.Parameter(torch.zeros(out_channels, in_channels))
        self.bias = nn.Parameter(torch.zeros(out_channels))

        mask = torch.zeros(out_channels, in_channels)
        for i, neighbour in enumerate(self.expanded_neighbour_list):
            mask[self._joint_rows(i), neighbour] = 1
        # Kept as a frozen Parameter so it follows the module across devices
        # and is stored in the state dict under the key "mask".
        self.mask = nn.Parameter(mask, requires_grad=False)

        self.reset_parameters()
        self.parameter_cnt = self.mask.sum() + out_channels
        self.description = 'SkeletonLinear(in_channels_per_armature={}, out_channels_per_armature={}, ' \
                           'joint_num={})\n' \
                           'Total Learnable Parameters={}'.format(
            in_channels // joint_num, out_channels // joint_num,
            joint_num, self.parameter_cnt
        )

    def _joint_rows(self, i):
        """Return the slice of output features that belongs to joint ``i``."""
        return slice(i * self.out_channels_per_joint, (i + 1) * self.out_channels_per_joint)

    def reset_parameters(self):
        """(Re-)initialise every joint's weight block and the bias.

        Weight blocks use Kaiming-uniform initialisation restricted to the
        joint's neighbourhood; the bias bound is derived from the full input
        width of the weight matrix.
        """
        # In-place writes to a leaf that requires grad are only legal without
        # autograd tracking.
        with torch.no_grad():
            for i, neighbour in enumerate(self.expanded_neighbour_list):
                rows = self._joint_rows(i)
                # Advanced indexing yields a copy, so initialise a temporary
                # block and write it back explicitly.
                block = torch.zeros_like(self.weight[rows, neighbour])
                nn.init.kaiming_uniform_(block, a=math.sqrt(5))
                self.weight[rows, neighbour] = block

            # fan-in of a 2-D weight is its number of input features
            fan_in = self.weight.shape[1]
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the masked linear map to the last axis of ``input``.

        ``F.linear`` accepts any number of leading axes, so 2-D and 3-D inputs
        behave as before while non-contiguous and higher-rank tensors are
        handled as well.  The result has the shape of ``input`` with the last
        axis replaced by ``out_channels``.
        """
        weight_masked = self.weight * self.mask
        return F.linear(input, weight_masked, self.bias)

    def __repr__(self):
        return self.description

class SkeletonPoolJoint(nn.Module):
    """Mean-pool joints of a skeleton along its kinematic chains.

    The pooling is expressed as a fixed (non-trainable) matrix ``weight`` that
    is left-multiplied onto the input in :meth:`forward`.  After construction
    the module exposes:

    * ``pooling_list``  - for every pooled joint, the original joints merged into it
    * ``pooling_map``   - for every original joint, the index of its pooled joint
    * ``new_topology``  - parent list of the pooled skeleton
    """

    def __init__(self, topology, channels_per_joint, last_pool=False):
        '''
        :topology: parent structure of the skeleton as a list; the root (index 0) has parent -1
        :channels_per_joint: number of channels every joint owns in the joint-major channel layout
        :last_pool: if true, pool each kinematic chain into one vector. Otherwise, walk up each
        chain from its end and pool the joints pairwise; joints with more than one child are kept
        '''
        super(SkeletonPoolJoint, self).__init__()

        self.joint_num = len(topology)
        self.parent = topology
        self.pooling_list = []
        self.pooling_map = [-1 for _ in range(len(self.parent))]
        # child[j] is the last child seen for joint j; it is only consulted
        # for joints that have exactly one child.
        self.child = [-1 for _ in range(len(self.parent))]
        children_cnt = [0 for _ in range(len(self.parent))]
        for x, pa in enumerate(self.parent):
            if pa < 0:
                continue
            children_cnt[pa] += 1
            self.child[pa] = x
        self.pooling_map[0] = 0

        for joint in range(self.joint_num):
            cnt = children_cnt[joint]
            if cnt > 1:
                # Branching joints are never merged with anything.
                self.pooling_list.append([joint])
                continue
            # A chain is walked upwards from its lowest joint: either a leaf
            # or a joint whose only child is a branching joint.
            starts_chain = cnt == 0 or children_cnt[self.child[joint]] > 1
            if not starts_chain:
                continue

            if last_pool:
                seq = [joint]
                pa = self.parent[joint]
                while pa != -1 and children_cnt[pa] == 1:
                    seq.insert(0, pa)
                    pa = self.parent[pa]
                self.pooling_list.append(seq)
                continue

            cur = joint
            # ``cur`` becomes -1 once the root itself was consumed as the
            # upper half of a pair; without this guard the walk would continue
            # through Python's negative indexing and never terminate.
            while cur != -1 and children_cnt[cur] <= 1:
                pa = self.parent[cur]
                if pa != -1 and children_cnt[pa] == 1:
                    self.pooling_list.append([pa, cur])
                    cur = self.parent[pa]
                else:
                    self.pooling_list.append([cur])
                    break

        self.description = 'SkeletonPool(in_joint_num={}, out_joint_num={})'.format(
            len(topology), len(self.pooling_list),
        )

        # Order pooled joints by the smallest original joint index they cover.
        self.pooling_list.sort(key=lambda group: group[0])
        for i, group in enumerate(self.pooling_list):
            for j in group:
                self.pooling_map[j] = i

        self.output_joint_num = len(self.pooling_list)
        self.new_topology = [-1 for _ in range(len(self.pooling_list))]
        for i, group in enumerate(self.pooling_list):
            if i < 1:
                continue
            # The parent of a pooled joint is the pooled joint that contains
            # the parent of its first member.
            self.new_topology[i] = self.pooling_map[self.parent[group[0]]]

        weight = torch.zeros(len(self.pooling_list) * channels_per_joint, self.joint_num * channels_per_joint)
        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_joint):
                    # Channel c of pooled joint i is the mean of channel c of its members.
                    weight[i * channels_per_joint + c, j * channels_per_joint + c] = 1.0 / len(group)

        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Left-multiply ``input`` by the pooling matrix.

        The second-to-last axis of ``input`` must have size
        ``joint_num * channels_per_joint``.
        """
        return torch.matmul(self.weight, input)

    def __repr__(self):
        return self.description


class SkeletonPool(nn.Module):
    """Mean-pool the edges of a skeleton along its kinematic chains.

    The skeleton is given as a list of ``(parent_joint, child_joint)`` edges
    rooted at joint 0.  One extra pseudo edge (index ``len(edges)``) is always
    appended and passed through unpooled; the source labels it the global
    position.  Pooling is a fixed (non-trainable) matrix ``weight`` that is
    left-multiplied onto the input in :meth:`forward`.  After construction the
    module exposes:

    * ``seq_list``      - edge indices of every unbranched chain
    * ``pooling_list``  - for every pooled edge, the original edges merged into it
    * ``new_edges``     - edges of the pooled skeleton (left empty when ``last_pool`` is set)
    """

    def __init__(self, edges, pooling_mode, channels_per_edge, last_pool=False):
        """
        :edges: list of (parent joint, child joint) pairs; joint 0 is the root
        :pooling_mode: only 'mean' is implemented
        :channels_per_edge: number of channels every edge owns in the edge-major channel layout
        :last_pool: if true, pool every chain into a single edge; otherwise pool edges pairwise
        :raises Exception: if ``pooling_mode`` is not 'mean'
        """
        super(SkeletonPool, self).__init__()

        if pooling_mode != 'mean':
            raise Exception('Unimplemented pooling mode in matrix_implementation')

        self.channels_per_edge = channels_per_edge
        self.pooling_mode = pooling_mode
        self.edge_num = len(edges) + 1
        self.seq_list = []
        self.pooling_list = []
        self.new_edges = []

        # Size the degree table from the data instead of a fixed 100 entries,
        # so skeletons with joint ids >= 100 work too.
        max_joint = max((max(edge[0], edge[1]) for edge in edges), default=0)
        degree = [0] * (max_joint + 1)
        # Child lookup per joint, in edge order, so the traversal does not
        # rescan the whole edge list at every joint.
        outgoing = {}
        for idx, edge in enumerate(edges):
            degree[edge[0]] += 1
            degree[edge[1]] += 1
            outgoing.setdefault(edge[0], []).append((idx, edge[1]))

        def find_seq(j, seq):
            """Depth-first walk from joint ``j``; ``seq`` holds the edges of the current chain."""
            if degree[j] > 2 and j != 0:
                # Branching joint: the chain leading here ends, new ones start.
                self.seq_list.append(seq)
                seq = []

            # A degree-1 joint is a leaf, unless it is the root, which merely
            # has a single child and must still be traversed.
            if degree[j] == 1 and j != 0:
                self.seq_list.append(seq)
                return

            for idx, child in outgoing.get(j, ()):
                find_seq(child, seq + [idx])

        find_seq(0, [])
        for seq in self.seq_list:
            if last_pool:
                self.pooling_list.append(seq)
                continue
            if len(seq) % 2 == 1:
                # Odd chain: the first edge stays on its own, the rest is paired.
                self.pooling_list.append([seq[0]])
                self.new_edges.append(edges[seq[0]])
                seq = seq[1:]
            for i in range(0, len(seq), 2):
                self.pooling_list.append([seq[i], seq[i + 1]])
                self.new_edges.append([edges[seq[i]][0], edges[seq[i + 1]][1]])

        # add global position
        self.pooling_list.append([self.edge_num - 1])

        self.description = 'SkeletonPool(in_edge_num={}, out_edge_num={})'.format(
            len(edges), len(self.pooling_list)
        )

        weight = torch.zeros(len(self.pooling_list) * channels_per_edge, self.edge_num * channels_per_edge)
        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_edge):
                    # Channel c of pooled edge i is the mean of channel c of its members.
                    weight[i * channels_per_edge + c, j * channels_per_edge + c] = 1.0 / len(group)

        self.weight = nn.Parameter(weight, requires_grad=False)

    def forward(self, input: torch.Tensor):
        """Left-multiply ``input`` by the pooling matrix.

        The second-to-last axis of ``input`` must have size
        ``edge_num * channels_per_edge``.
        """
        return torch.matmul(self.weight, input)

    def __repr__(self):
        return self.description


class SkeletonUnpool(nn.Module):
    """Undo a skeleton pooling step by copying each pooled entry to its members.

    Unpooling is a fixed (non-trainable) 0/1 matrix ``weight`` that is
    left-multiplied onto the input in :meth:`forward`: every original
    joint/edge listed in ``pooling_list[i]`` receives a copy of the channels
    of pooled entry ``i``.
    """

    def __init__(self, pooling_list, channels_per_edge):
        """
        :pooling_list: for every pooled entry, the list of original indices that were merged into it
        :channels_per_edge: number of channels every entry owns in the entry-major channel layout
        """
        super(SkeletonUnpool, self).__init__()
        self.pooling_list = pooling_list
        self.input_joint_num = len(pooling_list)
        self.channels_per_edge = channels_per_edge
        # The output size is the total number of original indices listed.
        self.output_joint_num = sum(len(group) for group in self.pooling_list)

        self.description = 'SkeletonUnpool(in_joint_num={}, out_joint_num={})'.format(
            self.input_joint_num, self.output_joint_num,
        )

        weight = torch.zeros(self.output_joint_num * channels_per_edge, self.input_joint_num * channels_per_edge)
        for i, group in enumerate(self.pooling_list):
            for j in group:
                for c in range(channels_per_edge):
                    # Channel c of original entry j is a copy of channel c of pooled entry i.
                    weight[j * channels_per_edge + c, i * channels_per_edge + c] = 1

        self.weight = nn.Parameter(weight)
        self.weight.requires_grad_(False)

    def forward(self, input: torch.Tensor):
        """Left-multiply ``input`` by the unpooling matrix.

        A 2-D input is treated as having an implicit trailing axis of size 1,
        which is removed again from the result.  Only that added axis is
        squeezed, so genuine size-1 axes (e.g. a batch of one) are preserved
        and the output rank always equals the input rank.
        """
        added_trailing_axis = len(input.shape) == 2
        if added_trailing_axis:
            input = input.unsqueeze(-1)
        out = torch.matmul(self.weight, input)
        if added_trailing_axis:
            out = out.squeeze(-1)
        return out

    def __repr__(self):
        return self.description


def find_neighbor_joint(parents, threshold):
    """Return, for every joint, the joints within ``threshold`` hops of it.

    :parents: parent index of every joint; the entry of joint 0 (the root) is ignored
    :threshold: maximum graph distance (number of skeleton edges) to count as a neighbour
    :return: list of lists; entry ``i`` holds, in ascending order, the indices of all
    joints whose distance to joint ``i`` is at most ``threshold`` (joint ``i`` itself
    included for ``threshold >= 0``)
    """
    n_joint = len(parents)
    # Large sentinel stands for "not connected"; np.int64 replaces the
    # np.int alias that was removed in NumPy 1.24.
    dist_mat = np.full((n_joint, n_joint), 100000, dtype=np.int64)
    for i, p in enumerate(parents):
        dist_mat[i, i] = 0
        if i != 0:
            dist_mat[i, p] = dist_mat[p, i] = 1

    # Floyd's algorithm, one vectorised relaxation per intermediate joint k.
    # Row k and column k are unchanged by step k, so this matches the
    # element-wise in-place update.
    for k in range(n_joint):
        via_k = dist_mat[:, k:k + 1] + dist_mat[k:k + 1, :]
        np.minimum(dist_mat, via_k, out=dist_mat)

    neighbor_list = []
    for i in range(n_joint):
        neighbor_list.append([j for j in range(n_joint) if dist_mat[i, j] <= threshold])

    return neighbor_list